import re
from typing import List, Optional

import json
import requests
from modelscope_agent.tools.base import BaseTool, register_tool
from modelscope_agent.tools.utils.openapi_utils import get_parameter_value
from pydantic import BaseModel
from requests.exceptions import RequestException, Timeout

# Upper bound on the number of HTTP attempts OpenAPIPluginTool.call makes for a
# single invocation. Only a requests ``Timeout`` leads to another attempt; any
# other request failure is raised immediately.
MAX_RETRY_TIMES = 3


class ParametersSchema(BaseModel):
    # Pydantic model describing one parameter of a tool: its name, a
    # human-readable description, whether it is required and its type name.
    # Not referenced by the other definitions visible in this file.
    name: str
    description: str
    # Defaults to True, i.e. a parameter is required unless stated otherwise.
    required: Optional[bool] = True
    type: str


class ToolSchema(BaseModel):
    # Pydantic model describing a tool: its name, a human-readable description
    # and the list of parameters (ParametersSchema entries) it accepts.
    # Not referenced by the other definitions visible in this file.
    name: str
    description: str
    parameters: List[ParametersSchema]


@register_tool('openapi_plugin')
class OpenAPIPluginTool(BaseTool):
    """Tool that performs one HTTP operation described by an OpenAPI schema.

    The per-tool settings (``url``, ``method``, ``header``, ``parameters``,
    ...) are read from ``cfg[name]``.  ``call`` turns the JSON arguments it
    receives into an HTTP request, sends it with ``requests`` and returns the
    response body as text.
    """
    name: str = 'api tool'
    description: str = 'This is a api tool that ...'
    parameters: list = []

    def __init__(self, cfg, name):
        """Load the tool definition stored under ``cfg[name]``.

        Args:
            cfg: Mapping from tool name to that tool's settings.
            name: Key of this tool inside ``cfg``; also stored as
                ``self.name``.
        """
        # NOTE(review): BaseTool.__init__ is invoked twice, once before and
        # once after the per-instance attributes are set. Presumably the
        # second call lets the base class rebuild whatever it derives from
        # name/description/parameters -- confirm against BaseTool before
        # removing either call.
        super().__init__(cfg)
        self.name = name
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)
        # remote call
        self.url = self.cfg.get('url', '')
        self.token = self.cfg.get('token', '')
        self.header = self.cfg.get('header', {})
        self.method = self.cfg.get('method', '')
        self.parameters = self.cfg.get('parameters', [])
        self.description = self.cfg.get('description',
                                        'This is a api tool that ...')
        self.responses_param = self.cfg.get('responses_param', [])
        super().__init__(cfg)

    def call(self, params: str, **kwargs):
        """Send the configured HTTP request built from ``params``.

        Args:
            params: JSON object (as a string) holding the call arguments.
                Keys containing ``.`` are expanded into nested objects for
                the POST/DELETE request body.
            **kwargs: Only echoed in the debug output of POST/DELETE calls.

        Returns:
            The response body decoded as UTF-8, or the string
            ``'Parameter Error'`` when ``_verify_args`` rejects ``params``.

        Raises:
            ValueError: If no ``url`` is configured, the configured method is
                not POST/DELETE/GET, the request fails, or every attempt
                timed out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        # Verify before parsing: the original parsed the JSON first, so a
        # malformed argument string raised instead of reaching the
        # 'Parameter Error' reply. (Presumably _verify_args hands back the
        # raw string when the arguments are unusable -- confirm in BaseTool.)
        verified = self._verify_args(params)
        if isinstance(verified, str):
            return 'Parameter Error'

        # The request body is built from the arguments exactly as received,
        # before query parameters are merged into ``verified`` below.
        remote_parsed_input = json.dumps(
            self._remote_parse_input(**json.loads(params)))

        # Per-call copies: writing into self.url / self.header would leak one
        # call's path values and headers into every later call (and into the
        # cfg dict that self.header may alias).
        url = self.url
        headers = dict(self.header or {})
        cookies = {}
        path_params = {}
        for parameter in self.parameters:
            value = get_parameter_value(parameter, verified)
            location = parameter['in']
            if location == 'path':
                path_params[parameter['name']] = value
            elif location == 'query':
                verified[parameter['name']] = value
            elif location == 'cookie':
                cookies[parameter['name']] = value
            elif location == 'header':
                headers[parameter['name']] = value

        for name, value in path_params.items():
            url = url.replace(f'{{{name}}}', f'{value}')

        if self.method in ('POST', 'DELETE'):
            print(f'data: {kwargs}')
            print(f'header: {headers}')
            return self._request_with_retry(
                wrap_unexpected=True,
                method=self.method,
                url=url,
                headers=headers,
                cookies=cookies,
                data=remote_parsed_input)

        if self.method == 'GET':
            # Fill any placeholder that is still present from the verified
            # arguments; str() keeps numeric values from breaking replace().
            new_url = url
            for placeholder in re.findall(r'\{(.*?)\}', url):
                if placeholder in verified:
                    new_url = new_url.replace('{' + placeholder + '}',
                                              str(verified[placeholder]))
                else:
                    print(
                        f'The parameter {placeholder} was not generated by the model.'
                    )

            print('GET:', new_url)
            print('GET:', url)
            # Cookie parameters are forwarded here too; previously they were
            # collected but only sent for POST/DELETE.
            return self._request_with_retry(
                method='GET',
                url=new_url,
                headers=headers,
                cookies=cookies,
                params=verified)

        # DELETE is accepted above as well; the message text is left as-is.
        raise ValueError(
            'Remote call method is invalid!We have POST and GET method.')

    def _request_with_retry(self, wrap_unexpected=False, **request_kwargs):
        """Send one request, trying again only when it times out.

        Args:
            wrap_unexpected: When True, exceptions that are not
                ``RequestException`` are re-raised as ``ValueError`` (the
                POST/DELETE behaviour); when False they propagate unchanged
                (the GET behaviour).
            **request_kwargs: Passed straight to ``requests.request``.

        Returns:
            The response body decoded as UTF-8.

        Raises:
            ValueError: On a failed request or when all attempts timed out.
        """
        # NOTE(review): no ``timeout`` is passed to requests here, so the
        # Timeout branch may rarely trigger in practice -- confirm whether a
        # timeout should be configured.
        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(**request_kwargs)
                if response.status_code != requests.codes.ok:
                    response.raise_for_status()
                return response.content.decode('utf-8')
            except Timeout:
                continue
            except RequestException as e:
                raise ValueError(self._describe_request_error(e)) from e
            except Exception as e:
                if wrap_unexpected:
                    raise ValueError(e) from e
                raise

        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    @staticmethod
    def _describe_request_error(error):
        """Build the error text for a failed request."""
        response = getattr(error, 'response', None)
        # Compare with None explicitly: a requests Response is falsy for
        # error statuses, so a truthiness test would hide real responses.
        if response is None:
            # Connection-level failures carry no HTTP response to report.
            return f'Remote call failed with error: {error}'
        body = response.content.decode('utf-8', errors='replace')
        return (f'Remote call failed with error code: {response.status_code}, '
                f'error message: {body}')

    def _remote_parse_input(self, *args, **kwargs):
        """Expand dotted keys into nested dictionaries.

        ``{'a.b': 1, 'c': 2}`` becomes ``{'a': {'b': 1}, 'c': 2}``.

        Returns:
            A new dictionary; the result is also printed for debugging.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            # A key without "." has no parents and is stored at the top level.
            *parents, leaf = key.split('.')
            target = restored_dict
            for parent in parents:
                target = target.setdefault(parent, {})
            target[leaf] = value
        print('传给tool的参数：', restored_dict)
        return restored_dict


def parse_responses_parameters(param_name, param_info, parameters_list):
    """Append flat descriptions of one response field to ``parameters_list``.

    A non-object field contributes a single entry.  An object field
    contributes one entry per direct property, named ``<param_name>.<prop>``;
    deeper levels are not expanded, and an object without properties
    contributes nothing.

    Args:
        param_name: Name of the response field.
        param_info: Schema mapping for the field; must contain ``'type'``.
        parameters_list: List extended in place with dicts holding the keys
            ``name``, ``description`` and ``type``.

    Raises:
        KeyError: If ``param_info`` has no ``'type'`` entry.
        ValueError: If anything fails while the entries are being built.
    """
    declared_type = param_info['type']
    # The fallback description is a placeholder; adjust the wording as needed.
    declared_description = param_info.get('description',
                                          f'调用api返回的{param_name}')
    try:
        if declared_type != 'object':
            # Non-nested field: recorded directly under its own name.
            parameters_list.append(
                dict(
                    name=param_name,
                    description=declared_description,
                    type=declared_type,
                ))
            return

        children = param_info.get('properties')
        if not children:
            return

        # Object with properties: emit one entry per direct property.
        for child_name, child_info in children.items():
            child_type = child_info['type']
            dotted_name = f'{param_name}.{child_name}'
            child_description = child_info.get('description',
                                               f'调用api返回的{dotted_name}')
            parameters_list.append(
                dict(
                    name=dotted_name,
                    description=child_description,
                    type=child_type,
                ))
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错')
